"""
WS-DAN models
Hu et al.,
"See Better Before Looking Closer: Weakly Supervised Data Augmentation Network for Fine-Grained Visual Classification",
arXiv:1901.09891
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

import cub_models.vgg as vgg
import cub_models.resnet as resnet
from cub_models.inception import inception_v3, BasicConv2d

__all__ = ['Model_Wrapper']
EPSILON = 1e-6


def weights_init_classifier(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.normal_(m.weight, std=0.001)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)

def weights_init_kaiming(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif classname.find('Conv') != -1:
        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif classname.find('BatchNorm') != -1:
        if m.affine:
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0.0)
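
# Usage sketch (the modules below are illustrative, not from this repo): these
# initializers are meant to be applied recursively via nn.Module.apply, e.g.
#   classifier = nn.Linear(512, 200)
#   classifier.apply(weights_init_classifier)
#   head = nn.Sequential(nn.Conv2d(3, 64, 3), nn.BatchNorm2d(64))
#   head.apply(weights_init_kaiming)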

# Bilinear Attention Pooling
class BAP(nn.Module):
    def __init__(self, pool='GAP'):
        super(BAP, self).__init__()
        assert pool in ['GAP', 'GMP']
        if pool == 'GAP':
            self.pool = None
        else:
            self.pool = nn.AdaptiveMaxPool2d(1)

    def forward(self, features, attentions):
        B, C, H, W = features.size()
        _, M, AH, AW = attentions.size()

        # Resize the attention maps to the feature map's spatial size.
        if AH != H or AW != W:
            attentions = F.interpolate(attentions, size=(H, W), mode='bilinear', align_corners=True)

        # Bilinear attention pooling: for each of the M attention maps, average
        # the attention-weighted features over all H * W positions, giving a
        # part-feature matrix of shape (B, M, C) that is flattened to (B, M * C).
        if self.pool is None:
            feature_matrix = (torch.einsum('imjk,injk->imn', attentions, features) / float(H * W)).view(B, -1)
            # feature_matrix = (torch.einsum('imjk,injk->imn', attentions, features) / float(H * W)).sum(dim=1)  # (B, M, C) -> (B, C)
        else:
            feature_matrix = []
            for i in range(M):
                AiF = self.pool(features * attentions[:, i:i + 1, ...]).view(B, -1)
                feature_matrix.append(AiF)
            feature_matrix = torch.cat(feature_matrix, dim=1)

        # Signed square-root normalization: y = sign(x) * sqrt(|x| + EPSILON).
        feature_matrix_raw = torch.sign(feature_matrix) * torch.sqrt(torch.abs(feature_matrix) + EPSILON)

        # L2-normalize over the flattened (M * C) dimension.
        feature_matrix = F.normalize(feature_matrix_raw, dim=-1)
        # feature_matrix = feature_matrix_raw
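        # Counterfactual branch: pool the same features with a "fake" attention
        # (random during training, uniform at eval) and return it alongside the
        # real feature matrix, so a training objective can contrast the two;
        # this matches the counterfactual attention learning idea
        # (Rao et al., ICCV 2021, arXiv:2108.08728).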

        if self.training:
            fake_att = torch.zeros_like(attentions).uniform_(0, 2)
        else:
            fake_att = torch.ones_like(attentions)
        counterfactual_feature = (torch.einsum('imjk,injk->imn', fake_att, features) / float(H * W)).view(B, -1)
        # counterfactual_feature = (torch.einsum('imjk,injk->imn', fake_att, features) / float(H * W)).sum(dim=1)  # (B, M, C) -> (B, C)

        counterfactual_feature = torch.sign(counterfactual_feature) * torch.sqrt(torch.abs(counterfactual_feature) + EPSILON)

        counterfactual_feature = F.normalize(counterfactual_feature, dim=-1)
        return feature_matrix, counterfactual_feature
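
# A minimal shape sketch for BAP (sizes are illustrative only):
#   bap = BAP(pool='GAP')
#   features = torch.randn(2, 768, 17, 17)    # (B, C, H, W)
#   attentions = torch.randn(2, 32, 17, 17)   # (B, M, H, W)
#   fm, cf = bap(features, attentions)        # each of shape (B, M * C) = (2, 24576)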


class Model_Wrapper(nn.Module):
    def __init__(self, num_classes, M=32, net='inception_mixed_6e', pretrained=False, pth_path=None):
        super(Model_Wrapper, self).__init__()
        self.num_classes = num_classes
        self.M = M
        self.net = net

        # Network Initialization
        if 'inception' in net:
            if net == 'inception_mixed_6e':
                self.features = inception_v3(pretrained=pretrained).get_features_mixed_6e()
                self.num_features = 768
            elif net == 'inception_mixed_7c':
                self.features = inception_v3(pretrained=pretrained).get_features_mixed_7c()
                self.num_features = 2048
            else:
                raise ValueError('Unsupported net: %s' % net)
        elif 'vgg' in net:
            self.features = getattr(vgg, net)(pretrained=pretrained).get_features()
            self.num_features = 512
        elif 'resnet' in net:
            self.features = getattr(resnet, net)(pretrained=pretrained, pth_path=pth_path)  # .get_features()
            self.num_features = 512 * 4  # 512 * block expansion (4 for Bottleneck resnets)
        elif 'att' in net:
            print('==> Using MANet with resnet101 backbone')
            self.features = MANet()  # NOTE: MANet is not imported in this module; it must be provided by the surrounding codebase.
            self.num_features = 2048
        else:
            raise ValueError('Unsupported net: %s' % net)

        # Attention maps and BAP (kept for reference; disabled in this wrapper):
        # self.attentions = BasicConv2d(self.num_features, self.M, kernel_size=1)
        # self.bap = BAP(pool='GAP')
        # self.pool = nn.AdaptiveAvgPool2d(1)

        # Two-layer classification head: num_features -> num_features // 4 -> num_classes.
        self.fc1 = nn.Linear(self.num_features, self.num_features // 4)
        self.fc2 = nn.Linear(self.num_features // 4, self.num_classes)
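
        # Dimension sketch: with the default inception_mixed_6e backbone,
        # num_features = 768, so fc1 maps 768 -> 192 and fc2 maps 192 -> num_classes.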


    def forward(self, x):
        # The backbone is expected to return three outputs; fm is the spatial
        # feature map and x1 the pooled global feature (the third is unused here).
        fm, x1, _ = self.features(x)

        # pooling (disabled: x1 is already pooled by the backbone)
        # x1 = self.pool(fm)  # dim = 2048
        # x1 = x1.view(x1.size(0), -1)
        x2 = self.fc1(x1)  # bottleneck feature, dim = num_features // 4
        # Classification
        pred = self.fc2(x2)

        return pred, x1, x2
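
        # Call sketch (hypothetical usage): pred holds the class logits, x1 the
        # backbone's pooled feature, and x2 the bottleneck feature:
        #   pred, x1, x2 = model(images)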
    
    def load_state_dict(self, state_dict, strict=True):
        # Non-strict loading: keep only checkpoint entries whose key exists in
        # this model and whose tensor shape matches; everything else is skipped.
        model_dict = self.state_dict()
        pretrained_dict = {k: v for k, v in state_dict.items()
                           if k in model_dict and model_dict[k].size() == v.size()}

        if len(pretrained_dict) == len(state_dict):
            print('%s: All params loaded' % type(self).__name__)
        else:
            print('%s: Some params were not loaded:' % type(self).__name__)
            not_loaded_keys = [k for k in state_dict.keys() if k not in pretrained_dict]
            print(', '.join(not_loaded_keys))

        model_dict.update(pretrained_dict)
        super(Model_Wrapper, self).load_state_dict(model_dict)
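

if __name__ == '__main__':
    # Hedged smoke test for the pieces defined in this file: it exercises BAP
    # and the init helpers directly, and only constructs Model_Wrapper
    # (forward() is not called here because the triple-return convention is
    # defined by the cub_models backbones, not by this module).
    bap = BAP(pool='GAP')
    feats = torch.randn(2, 768, 17, 17)                  # (B, C, H, W)
    atts = torch.randn(2, 32, 17, 17)                    # (B, M, H, W)
    fm, cf = bap(feats, atts)
    print('BAP feature matrix:', tuple(fm.shape))        # (2, 24576)
    print('counterfactual feature:', tuple(cf.shape))    # (2, 24576)

    head = nn.Sequential(nn.Linear(2048, 512), nn.BatchNorm1d(512))
    head.apply(weights_init_kaiming)

    model = Model_Wrapper(num_classes=200, M=32, net='inception_mixed_6e', pretrained=False)
    print('Model_Wrapper(%s): %d params' % (model.net, sum(p.numel() for p in model.parameters())))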
